"""
 @Author: zzgsg
 @time: 2024/9/28 20:32
"""

# Figure 5. Comparison between the synthetic and the retrieved NO2 concentration profiles.
import numpy as np
import subprocess
import shutil
import os
import pandas as pd
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d
import matplotlib
matplotlib.rcParams['mathtext.default'] = 'regular'

def e_profile(c, H, z):
    """Exponentially decreasing profile ``c/H * exp(-z/H)``.

    ``H`` is the scale height (same units as ``z``). The profile integrates
    to ``c`` over ``z`` in ``[0, inf)``.
    """
    peak = c / H  # value at z = 0
    return peak * np.exp(-z / H)

def box_profile(c, H, z):
    """Constant profile of value ``c/H`` for ``z <= H`` and zero above ``H``.

    The non-zero part integrates to ``c`` over ``[0, H]``.
    """
    profile = np.ones_like(z) * c / H
    above_top = z > H  # the top edge z == H itself stays inside the box
    profile[above_top] = 0
    return profile

def layered_profile(c, H1, H2, z):
    """Constant profile of value ``c/(H2-H1)`` inside ``(H1, H2]`` and zero elsewhere.

    The non-zero layer integrates to ``c``.
    """
    profile = np.ones_like(z) * c / (H2 - H1)
    # Lower edge is exclusive (z <= H1 is cleared), upper edge inclusive.
    outside = (z <= H1) | (z > H2)
    profile[outside] = 0
    return profile

def gauss_profile(c, x0, sigma, z):
    """Gaussian profile centred at ``x0`` with standard deviation ``sigma``.

    The amplitude is chosen so the curve integrates to ``c`` over the real line.
    """
    amplitude = c / (sigma * np.sqrt(2 * np.pi))  # normalisation so the area equals c
    exponent = -(z - x0) ** 2 / (2 * sigma ** 2)
    return amplitude * np.exp(exponent)

def Generate_Profile(z1, z2, z3):
    """Build the synthetic aerosol and trace-gas profiles used for Figure 5.

    Parameters
    ----------
    z1 : array
        Altitude grid on which the ten aerosol profiles are evaluated.
    z2 : array
        Altitude grid on which the six trace-gas profiles are evaluated.
    z3 : array
        Accepted for interface compatibility; not used.

    Returns
    -------
    tuple of (dict, dict)
        Aerosol profiles keyed ``'aero0'``..``'aero9'`` and trace-gas
        profiles keyed ``'gas0'``..``'gas5'``.
    """
    aerosols = {
        'aero0': 1e-10 * np.ones_like(z1),  # tiny constant (near-zero) profile
        'aero1': e_profile(0.25, 1, z1),
        'aero2': e_profile(1, 1, z1),
        'aero3': e_profile(0.25, 0.25, z1),
        'aero4': box_profile(0.1, 0.2, z1),
        'aero5': box_profile(0.5, 1, z1),
        'aero6': gauss_profile(0.25, 1, 0.3, z1),
        'aero7': box_profile(2, 0.2, z1),
        'aero8': layered_profile(5, 1, 1.5, z1),
        'aero9': layered_profile(5, 5, 5.5, z1),
    }
    gases = {
        'gas0': e_profile(1, 1, z2),
        'gas1': e_profile(1, 0.25, z2),
        'gas2': box_profile(1, 0.3, z2),
        'gas3': box_profile(1, 1, z2),
        'gas4': gauss_profile(1, 1.5, 0.3, z2),
        'gas5': layered_profile(1, 1, 2, z2),
    }
    return aerosols, gases
# Altitude grids; presumably km, since the figure's y axis is labelled 'Altitude [km]'.
# z1: 0-6 every 0.1, then 7, 8, 9, then 10-40 every 6.
z1 = np.r_[np.arange(0., 6.01, 0.1), np.arange(7, 10, 1), np.arange(10, 41, 6)]
# Finer grid boundaries and their layer midpoints.
# NOTE(review): z2 is computed but not used below. The gas profiles are evaluated on z1,
# which the staircase "REAL" curves in the plotting loop depend on (they pair values with z1 layers).
z_grid_middle = np.r_[np.arange(0.01, 4, 0.01), np.arange(4, 10, 1), np.arange(10, 41, 5)]
z2 = (z_grid_middle[:-1] + z_grid_middle[1:]) / 2
# RTM altitude grid: one number per line. Skip any blank or whitespace-only line;
# float() tolerates surrounding whitespace, including the trailing newline.
with open('RTM_grid_bak.dat', 'r') as f:
    z_grid = [float(line) for line in f if line.strip()]
z3 = np.array(z_grid)
aero_dict, gas_dict = Generate_Profile(z1, z1, z3)

# Figure 5: a 10 x 6 grid of panels. Row i is labelled AERO{i} and column num_j is labelled TG{num_j}.
fig3, axs = plt.subplots(10, 6, figsize=(30, 35))
plt.subplots_adjust(top=0.95, bottom=0.14, right=0.97, left=0.11, hspace=0., wspace=0.)
plt.rcParams['font.size'] = 40
set_max = [12, 30, 45, 21, 21, 21]  # x-axis upper limit for each column
n_geom = 9  # retrieved curves per panel, labelled GEOM0..GEOM8

# Read the retrieval results once, instead of twice per panel.
# After transposing, row 0 is the altitude grid and each following block of
# n_geom rows belongs to one panel, in row-major order (index i * 6 + num_j).
retrieved = pd.read_csv('synthetic_retrieved_tracegas.dat', sep='\t').values.T
z_grid = retrieved[0]
n_expected = 1 + n_geom * axs.size
if retrieved.shape[0] < n_expected:
    raise ValueError(
        'synthetic_retrieved_tracegas.dat has {} columns, expected at least {}'.format(
            retrieved.shape[0], n_expected))

# Altitude vertices of the staircase used to draw the true profile;
# these depend only on z1, so they are built once.
b = np.c_[z1[:-1], z1[1:]].flatten()

for i in range(10):
    for num_j in range(6):
        ax3 = axs[i, num_j]
        for side in ('bottom', 'left', 'right', 'top'):
            ax3.spines[side].set_linewidth(4)
        # set_xlim turns off x autoscaling, so the later plot() calls keep these limits.
        ax3.set_xlim([0, set_max[num_j]])
        ax3.set_xticks([0, set_max[num_j] / 3, set_max[num_j] * 2 / 3])
        ax3.grid(axis='x', linestyle='-.', linewidth=3)
        plt.setp(ax3.get_xticklabels(), fontsize=30)
        plt.setp(ax3.get_yticklabels(), fontsize=30)

        start = 1 + n_geom * (i * 6 + num_j)
        prof = retrieved[start:start + n_geom]
        for j in range(n_geom):
            ax3.plot(prof[j], z_grid, label='GEOM{}'.format(j), linewidth=3)

        # True profile as a staircase. Each value is repeated so it spans its z1 layer.
        # NOTE(review): the factor 10 presumably converts to the x-axis units; confirm.
        gas = gas_dict['gas{}'.format(num_j)] * 10
        a = np.c_[gas, gas].flatten()
        ax3.plot(a[:-2], b, c='k', label='REAL', linewidth=3)

        ax3.set_yticks([0, 1, 2, 3])
        ax3.set_ylim([0, 4])
        if num_j != 0:
            ax3.set_yticks([])  # inner columns share the y axis of the first column
        else:
            ax3.set_ylabel('AERO{}'.format(i), fontsize=40, labelpad=12)
        if i != 9:
            plt.setp(ax3.get_xticklabels(), visible=False)
        else:
            ax3.set_xlabel('TG{}'.format(num_j), fontsize=40, labelpad=12)

# One shared legend, built from the last panel and anchored to the figure's bottom centre.
ax3.legend(loc='lower center', bbox_to_anchor=(0.5, 0), bbox_transform=plt.gcf().transFigure, ncol=5)

# Invisible full-figure axes used only to place the shared axis titles.
ax3 = plt.axes([0, 0, 1, 1])
ax3.axis('off')
ax3.annotate('$NO_2$ concentration [$10^{11}$ molec $cm^{-3}$]', xy=(0.54, 0.09), xycoords='axes fraction',
             horizontalalignment='center', verticalalignment='center', fontsize=40)
ax3.annotate('Altitude [km]', xy=(0.03, 0.545), xycoords='axes fraction', rotation=90,
             horizontalalignment='center', verticalalignment='center', fontsize=40)
os.makedirs('temp', exist_ok=True)  # savefig does not create missing directories
plt.savefig('temp/Figure5.eps')
plt.savefig('temp/Figure5.png')
plt.close()